from dataclasses import dataclass
import torch
from Data_generator import get_random_problems, get_random_eval_problems


@dataclass
class Reset_State:
    """Observation returned by the environment's ``reset`` call."""

    # Processing times of the loaded problem instances.
    # shape: (batch, job, mc)
    problems: torch.Tensor

@dataclass
class Step_State:
    """Per-step observation shared between the environment and its caller."""

    # Index grids used for advanced indexing into per-rollout tensors.
    # shape: (batch, pomo)
    BATCH_IDX: torch.Tensor
    POMO_IDX: torch.Tensor

    # Job chosen at the most recent step; None until the first step.
    # shape: (batch, pomo)
    current_node: torch.Tensor = None

    # Additive mask: 0 for selectable jobs, -inf for jobs already chosen.
    # shape: (batch, pomo, job)
    ninf_mask: torch.Tensor = None

    # Per-machine timing buffer, zero-initialised by the environment.
    # shape: (batch, pomo, 2 * mc)
    machine_time: torch.Tensor = None


class PFSPEnv:
    """Batched flow-shop scheduling environment with parallel rollouts.

    Every problem instance is a ``(job, mc)`` matrix of processing times.
    Each instance is rolled out ``pomo_size`` times in parallel; at every
    step one job index per rollout is appended to that rollout's sequence.
    Once ``job_cnt`` jobs have been selected the reward is the negative
    makespan of the resulting job order.

    Typical call order: ``load_problems`` / ``load_problems_manual`` ->
    ``reset`` -> ``pre_step`` -> repeated ``step``.
    """

    def __init__(self, **env_params):
        """Store the static configuration.

        Args:
            **env_params: must contain ``pomo_size`` (rollouts per instance),
                ``job_cnt`` (jobs per instance), ``mc_cnt`` (machines per
                instance) and ``mode`` (forwarded unchanged to the problem
                generators).

        Raises:
            KeyError: if one of the required keys is missing.
        """
        # Const @INIT
        ####################################
        self.env_params = env_params
        self.pomo_size = env_params['pomo_size']
        self.n_jobs = env_params['job_cnt']
        self.n_mc = env_params['mc_cnt']
        self.mode = env_params['mode']

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.problems = None
        # shape: (batch, job, mc)
        # NOTE: the misspelled name is kept on purpose; code outside this
        # class may assign it. When left as None, ``self.problems`` is used.
        self.origianl_problems = None

        # Dynamic
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~job)

        # STEP-State
        ####################################
        self.step_state = None

    def _install_problems(self, problems, batch_size):
        """Store *problems* and build the (batch, pomo) index grids on its device."""
        self.batch_size = batch_size
        self.problems = problems
        # shape: (batch, job, mc)

        device = problems.device
        self.BATCH_IDX = torch.arange(batch_size, device=device)[:, None].expand(batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size, device=device)[None, :].expand(batch_size, self.pomo_size)
        # IDX.shape: (batch, pomo)

    def load_problems(self, batch_size, proj_type, eval_seed=1235):
        """Generate a fresh batch of problem instances.

        Args:
            batch_size: number of instances to generate.
            proj_type: ``'train'`` uses ``get_random_problems``; any other
                value uses ``get_random_eval_problems`` with a fixed seed.
            eval_seed: seed passed to the evaluation generator (ignored when
                ``proj_type == 'train'``).
        """
        if proj_type == 'train':
            problems = get_random_problems(batch_size, self.n_jobs, self.n_mc, self.mode)
        else:
            problems = get_random_eval_problems(batch_size, self.n_jobs, self.n_mc, self.mode, seed=eval_seed)
        # shape: (batch, job, mc)

        self._install_problems(problems, batch_size)

    def load_problems_manual(self, problems):
        """Use caller-supplied instances; the batch size is taken from dim 0.

        Args:
            problems: processing-time tensor, shape (batch, job, mc).
        """
        self._install_problems(problems, problems.size(0))

    def reset(self):
        """Start a new episode on the currently loaded problems.

        Returns:
            ``(Reset_State, None, False)`` - the loaded problems, no reward,
            and ``done=False``.

        Raises:
            RuntimeError: if no problems have been loaded yet.
        """
        if self.problems is None:
            raise RuntimeError(
                "No problems loaded; call load_problems() or load_problems_manual() before reset()."
            )

        self.selected_count = 0
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = torch.empty(
            (self.batch_size, self.pomo_size, 0), dtype=torch.long, device=self.problems.device
        )
        # shape: (batch, pomo, 0~job)

        self._create_step_state()

        reward = None
        done = False
        return Reset_State(problems=self.problems), reward, done

    def _create_step_state(self):
        """Build a fresh step state with an all-zero mask and timing buffer."""
        device = self.problems.device
        self.step_state = Step_State(BATCH_IDX=self.BATCH_IDX, POMO_IDX=self.POMO_IDX)
        self.step_state.ninf_mask = torch.zeros(
            (self.batch_size, self.pomo_size, self.n_jobs), device=device
        )
        # shape: (batch, pomo, job)
        self.step_state.machine_time = torch.zeros(
            (self.batch_size, self.pomo_size, 2 * self.n_mc), device=device
        )
        # shape: (batch, pomo, 2 * mc)

    def pre_step(self):
        """Return the current step state without advancing the episode."""
        reward = None
        done = False
        return self.step_state, reward, done

    def step(self, node_idx):
        """Append one selected job per rollout and advance the episode.

        Args:
            node_idx: chosen job indices (integer tensor), shape (batch, pomo).

        Returns:
            ``(step_state, reward, done)``. ``reward`` is ``None`` until
            ``job_cnt`` jobs have been selected; on the final step it is the
            negative makespan with shape (batch, pomo).
        """
        self.selected_count += 1

        self.current_node = node_idx
        # shape: (batch, pomo)
        self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
        # shape: (batch, pomo, 0~job)

        self._update_step_state()

        # returning values
        done = (self.selected_count == self.n_jobs)
        if done:
            reward = -self._makespan()  # Note the MINUS Sign ==> We MAXIMIZE reward
            # shape: (batch, pomo)
        else:
            reward = None
        return self.step_state, reward, done

    def _update_step_state(self):
        """Publish the latest selection and mask it out for later steps."""
        self.step_state.current_node = self.current_node
        # shape: (batch, pomo)
        # A job that has been scheduled can never be chosen again.
        self.step_state.ninf_mask[self.BATCH_IDX, self.POMO_IDX, self.current_node] = float('-inf')

    def _makespan(self):
        """Compute the makespan of every rollout's current job sequence.

        Uses the flow-shop recurrence
        ``C[j, m] = max(C[j, m-1], C[j-1, m]) + p[j, m]`` and keeps only the
        completion times of the most recently scheduled job, so memory is
        (batch, pomo, mc) instead of (batch, pomo, job, mc).

        Works for partial sequences too (fewer than ``job_cnt`` jobs); an
        empty sequence has makespan 0.

        Returns:
            Tensor of shape (batch, pomo) in torch's default floating dtype.
        """
        batch_size, sampling_size, seq_len = self.selected_node_list.shape
        num_machines = self.problems.size(2)
        device = self.problems.device

        if seq_len == 0:
            return torch.zeros(batch_size, sampling_size, device=device)

        # Reorder each instance's rows into the order chosen by each rollout.
        gather_index = self.selected_node_list.unsqueeze(-1).expand(-1, -1, -1, num_machines)
        problems_expanded = self.problems.unsqueeze(1).expand(-1, sampling_size, -1, -1)
        ordered_processing_times = torch.gather(problems_expanded, 2, gather_index)
        # shape: (batch, pomo, seq_len, mc)

        # completion[..., m] = completion time on machine m of the last job
        # processed so far. The first job never waits, hence a plain cumsum.
        completion = torch.zeros(batch_size, sampling_size, num_machines, device=device)
        completion[...] = torch.cumsum(ordered_processing_times[:, :, 0, :], dim=2)

        for j in range(1, seq_len):
            # Machine 0 is only ever blocked by the previous job.
            completion[:, :, 0] = completion[:, :, 0] + ordered_processing_times[:, :, j, 0]
            for m in range(1, num_machines):
                # completion[..., m - 1] already belongs to job j, while
                # completion[..., m] still belongs to job j - 1.
                ready = torch.max(completion[:, :, m - 1], completion[:, :, m])
                completion[:, :, m] = ready + ordered_processing_times[:, :, j, m]

        return completion[:, :, -1]

    def update_machine_times(self, machine_times):
        """Schedule ``self.current_node`` after the state in *machine_times*.

        Args:
            machine_times: tensor of shape (batch, pomo, 2 * mc); the first
                ``mc`` columns hold the last start time per machine, the
                remaining columns the last end time per machine.

        Returns:
            A new tensor of the same shape with start/end times updated for
            the job in ``self.current_node``. The input is not modified.
        """
        num_mc = self.n_mc

        # ``origianl_problems`` is never assigned inside this class; fall back
        # to the loaded problems unless outside code has provided it.
        source = self.origianl_problems if self.origianl_problems is not None else self.problems

        job_index = self.current_node.unsqueeze(-1).expand(-1, -1, source.shape[2])  # (batch, pomo, mc)
        selected_processing_time = source.gather(1, job_index)  # (batch, pomo, mc)

        # Work on a copy so the caller's tensor is left untouched.
        updated_machine_times = machine_times.clone()
        start_times = updated_machine_times[:, :, :num_mc]  # view, (batch, pomo, mc)
        end_times = updated_machine_times[:, :, num_mc:]    # view, (batch, pomo, mc)

        for mc in range(num_mc):
            machine_free_at = end_times[:, :, mc]
            if mc == 0:
                new_start = torch.max(machine_free_at, torch.zeros_like(machine_free_at))
            else:
                # end_times[..., mc - 1] was overwritten in the previous
                # iteration, so it is this job's finish on the upstream machine.
                new_start = torch.max(machine_free_at, end_times[:, :, mc - 1])

            start_times[:, :, mc] = new_start
            end_times[:, :, mc] = new_start + selected_processing_time[:, :, mc]

        return updated_machine_times  # (batch, pomo, 2 * mc)
